/*
 * Copyright 2024-2025 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.alibaba.cloud.ai.graph.agent.hook.messages;

import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.RunnableConfig;
import com.alibaba.cloud.ai.graph.action.AsyncNodeActionWithConfig;
import com.alibaba.cloud.ai.graph.agent.hook.Hook;
import com.alibaba.cloud.ai.graph.state.ReplaceAllWith;

import org.springframework.ai.chat.messages.Message;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;

/**
 * Base class for hooks that inspect or rewrite the message list before and after an
 * agent runs.
 *
 * <p>Subclasses override {@link #beforeAgent(List, RunnableConfig)} and/or
 * {@link #afterAgent(List, RunnableConfig)}. The nested {@link BeforeAgentAction} and
 * {@link AfterAgentAction} classes adapt those callbacks to
 * {@link AsyncNodeActionWithConfig}: they read the current messages from the state,
 * invoke the callback, and translate the returned {@link AgentCommand} into a state
 * update map.
 */
public abstract class MessagesAgentHook implements Hook {
	/** State key the messages are read from and written back to. */
	private static final String MESSAGES_KEY = "messages";

	/** Result key that carries the name of the requested jump target. */
	private static final String JUMP_TO_KEY = "jump_to";

	private String agentName;

	/**
	 * Callback invoked by {@link BeforeAgentAction} with the messages currently held in
	 * the state. The default implementation returns a command wrapping the given
	 * messages as-is.
	 * @param previousMessages the messages read from the state (an empty list when the
	 * state holds none)
	 * @param config the runnable config passed to the node action
	 * @return the command describing the message update and optional jump target
	 */
	public AgentCommand beforeAgent(List<Message> previousMessages, RunnableConfig config) {
		return new AgentCommand(previousMessages);
	}

	/**
	 * Callback invoked by {@link AfterAgentAction} with the messages currently held in
	 * the state. The default implementation returns a command wrapping the given
	 * messages as-is.
	 * @param previousMessages the messages read from the state (an empty list when the
	 * state holds none)
	 * @param config the runnable config passed to the node action
	 * @return the command describing the message update and optional jump target
	 */
	public AgentCommand afterAgent(List<Message> previousMessages, RunnableConfig config) {
		return new AgentCommand(previousMessages);
	}

	/**
	 * Sets the agent name stored on this hook.
	 * @param agentName the agent name
	 */
	public void setAgentName(String agentName) {
		this.agentName = agentName;
	}

	/**
	 * Returns the agent name stored on this hook.
	 * @return the agent name, or {@code null} if it has not been set
	 */
	public String getAgentName() {
		return agentName;
	}

	/**
	 * Creates a BeforeAgentAction instance for the given MessagesAgentHook.
	 * @param hook the MessagesAgentHook instance to proxy
	 * @return a BeforeAgentAction instance
	 */
	public static BeforeAgentAction beforeAgentAction(MessagesAgentHook hook) {
		return new BeforeAgentAction(hook);
	}

	/**
	 * Creates an AfterAgentAction instance for the given MessagesAgentHook.
	 * @param hook the MessagesAgentHook instance to proxy
	 * @return an AfterAgentAction instance
	 */
	public static AfterAgentAction afterAgentAction(MessagesAgentHook hook) {
		return new AfterAgentAction(hook);
	}

	/**
	 * Reads the message list stored in the state under {@value #MESSAGES_KEY}.
	 * @param state the state to read from
	 * @return the stored messages, or an empty immutable list when the key is absent
	 */
	private static List<Message> readMessages(OverAllState state) {
		// The state stores untyped values; the cast is unchecked by necessity.
		@SuppressWarnings("unchecked")
		List<Message> messages = (List<Message>) state.value(MESSAGES_KEY).orElse(List.of());
		return messages;
	}

	/**
	 * Translates a hook's {@link AgentCommand} into the map returned by the node action.
	 * @param command the command returned by the hook; {@code null} is treated as
	 * "no state change"
	 * @return a mutable map holding the message update and/or jump target, possibly empty
	 */
	private static Map<String, Object> toStateUpdate(AgentCommand command) {
		Map<String, Object> result = new HashMap<>();
		if (command == null) {
			// A subclass override returned nothing: emit an empty update instead of
			// failing with a NullPointerException inside the node action.
			return result;
		}

		List<Message> messages = command.getMessages();
		if (messages != null) {
			if (UpdatePolicy.REPLACE == command.getUpdatePolicy()) {
				// Wrap the list so it is marked as a full replacement of the stored value.
				result.put(MESSAGES_KEY, ReplaceAllWith.of(messages));
			} else {
				// Any other policy passes the raw list through; how it is merged is
				// decided by the state's handling of this key (not visible here).
				result.put(MESSAGES_KEY, messages);
			}
		}

		if (command.getJumpTo() != null) {
			result.put(JUMP_TO_KEY, command.getJumpTo().name());
		}
		return result;
	}

	/**
	 * Adapter that exposes {@link MessagesAgentHook#beforeAgent(List, RunnableConfig)}
	 * as an {@link AsyncNodeActionWithConfig}.
	 */
	public static class BeforeAgentAction implements AsyncNodeActionWithConfig {
		private final MessagesAgentHook messagesAgentHook;

		/**
		 * Creates an action that delegates to the given hook.
		 * @param messagesAgentHook the hook whose {@code beforeAgent} callback is invoked
		 */
		public BeforeAgentAction(MessagesAgentHook messagesAgentHook) {
			this.messagesAgentHook = messagesAgentHook;
		}

		/**
		 * Invokes the hook's {@code beforeAgent} callback with the messages in the state
		 * and converts the returned command into a state update.
		 * @param state the current state
		 * @param config the runnable config forwarded to the hook
		 * @return an already-completed future holding the state update map
		 */
		@Override
		public CompletableFuture<Map<String, Object>> apply(OverAllState state, RunnableConfig config) {
			AgentCommand command = messagesAgentHook.beforeAgent(readMessages(state), config);
			return CompletableFuture.completedFuture(toStateUpdate(command));
		}
	}

	/**
	 * Adapter that exposes {@link MessagesAgentHook#afterAgent(List, RunnableConfig)}
	 * as an {@link AsyncNodeActionWithConfig}.
	 */
	public static class AfterAgentAction implements AsyncNodeActionWithConfig {
		private final MessagesAgentHook messagesAgentHook;

		/**
		 * Creates an action that delegates to the given hook.
		 * @param messagesAgentHook the hook whose {@code afterAgent} callback is invoked
		 */
		public AfterAgentAction(MessagesAgentHook messagesAgentHook) {
			this.messagesAgentHook = messagesAgentHook;
		}

		/**
		 * Invokes the hook's {@code afterAgent} callback with the messages in the state
		 * and converts the returned command into a state update.
		 * @param state the current state
		 * @param config the runnable config forwarded to the hook
		 * @return an already-completed future holding the state update map
		 */
		@Override
		public CompletableFuture<Map<String, Object>> apply(OverAllState state, RunnableConfig config) {
			AgentCommand command = messagesAgentHook.afterAgent(readMessages(state), config);
			return CompletableFuture.completedFuture(toStateUpdate(command));
		}
	}

}

